% Main Simulation: iid case + approximation

clear all
warning off
clc


% 0. parameter setup
% Simulation size and designs. The s-th design uses TT(s), NN(s) and KK(s) together.
B    = 1000;              % number of replications per design
TT   = [50, 100, 200];    % rows of each sample used for estimation
NN   = [100, 200, 300];   % number of columns of the generated regressor matrix
KK   = [2, 4, 6];         % size of the leading block of Phi0 used by the DGP
T0   = 1;                 % extra rows generated per sample and held out for evaluation
% Noise scales handed to DGP_gen2.
sigu = 5;
sigy = [1, 0.1];
% 9x9 tridiagonal matrix: diagonal 1, 1.5, ..., 5 and 0.1 on both first off-diagonals.
Phi0 = diag(1:0.5:5) + 0.1*(diag(ones(1,8),1) + diag(ones(1,8),-1));
L    = [5, 10, 20];       % values of the last argument passed to pc_est
TUN  = 0.1:0.1:1;         % grid handed to est_cv

 
% 1. main estimation
% For every noise level sigy(t) and every design s, run B replications in
% parallel, average the squared and absolute forecast errors over the
% replications, and store the results in SAVE{t} (squared-error rows on top,
% absolute-error rows below, one row per design).
tic
for t = 1:numel(sigy)
    for s = 1:numel(TT)
        T = TT(s);
        N = NN(s);
        K = KK(s);
        parfor i = 1:B
            warning off
            % 2. ready DGP: draw T+T0 rows, estimate on the first T, evaluate on the rest
            [y0,X0,~,w0] = DGP_gen2(N,T+T0,Phi0(1:K,1:K),sigu,sigy(t));
            Xt = X0(1:T,:);
            yt = y0(1:T,:);
            Xe = X0(T+1:end,:);
            ye = y0(T+1:end,:);
            % 3. forecasting: one column per method, appended in a fixed order
            fc = Xe*w0;                                   % weights w0 returned by the DGP
            fc = [fc, mean(Xe,2)];                        % simple average
            fc = [fc, Xe*est_cv(yt,Xt,TUN,1)];            % LASSO
            fc = [fc, Xe*est_cv(yt,Xt,TUN,2)];            % RIDGE
            for j = 1:numel(L)
                fc = [fc, pc_est(yt,Xt,Xe,K,L(j))];       % PC, one column per L(j)
            end
            for m = 3:5
                % m = 3: L2RELAX no shrinkage, 4: linear shrinkage, 5: nonlinear shrinkage
                fc = [fc, Xe*est_cv(yt,Xt,TUN,m)];
            end
            err = ye - fc;                                % ye is expanded across the columns
            E(i,:) = mean(err.^2,1);
            E2(i,:) = mean(abs(err),1);
        end
        % average over the B replications
        SAVE1(s,:) = mean(E,1);
        SAVE2(s,:) = mean(E2,1);
    end
    SAVE{t} = [SAVE1;SAVE2];
end
toc


% 4. show results
% SAVE{t} stacks length(TT) rows of mean squared errors on top of length(TT)
% rows of mean absolute errors. The row split and the label columns are
% derived from TT and SAVE so the tables stay correct when the design
% vectors or sigy at the top of the script change length.
nS = length(TT);                                   % rows per error measure in SAVE{t}
nNoise = length(SAVE);                             % one cell per entry of sigy
labels = repmat([TT(:) NN(:) KK(:)], nNoise, 1);   % (T, N, K) repeated for each noise level
% Benchmark subtracted from each table: sigy^2 for squared errors and
% 2/sqrt(2*pi)*sigy (the mean absolute value of a N(0,sigy^2) draw) for
% absolute errors.
bench = {@(sd) sd^2, @(sd) 2/sqrt(2*pi)*sd};
for k = 1:2                                        % 4.1 MSFE (k = 1), 4.2 MAFE (k = 2)
    rows = (k-1)*nS + (1:nS);
    parts = cell(nNoise,1);
    for t = 1:nNoise
        parts{t} = SAVE{t}(rows,:) - bench{k}(sigy(t));
    end
    Rt = vertcat(parts{:});
    prt_tab([labels Rt],{},3)
end

%% Local Functions
function [y,x,R2,w0,SUB] = DGP_gen2(N,T,Phi0,sigu,sigy)
% DGP_GEN2  Draw one sample (y, x) together with the weight vector w0.
%
%   Inputs
%     N     number of columns of x (N/size(Phi0,1) is used as a block size)
%     T     number of rows to generate
%     Phi0  K-by-K matrix, expanded to N-by-N by repeating each entry over a block
%     sigu  scale of the noise u added to x
%     sigy  scale of the noise uy added to y
%
%   Outputs
%     y     T-by-1 vector
%     x     T-by-N matrix
%     R2    1 - SUB/(w0'*Phi*w0 + sigy^2)
%     w0    N-by-1 weight vector (entries sum to one)
%     SUB   w0'*(Ou - Ou*(Phi+Ou)^(-1)*Ou)*w0 + sigy^2, with Ou = sigu^2*I

    % 0. basic parameters
    K = size(Phi0,1);
    Nk = N/K;                                   % columns per block
    Phi = kron(Phi0,ones(Nk));
    rowOnes = ones(1,K);
    colOnes = ones(K,1);
    Phi0inv = Phi0^(-1);
    wBlock = Phi0inv*colOnes/(rowOnes*Phi0inv*colOnes);   % K weights, normalised to sum to one
    w0 = kron(wBlock,ones(Nk,1)/Nk);            % spread each block weight evenly over its Nk columns
    Ou = sigu^2*eye(N);                         % Omega_u
    SUB = w0'*(Ou-Ou*(Phi+Ou)^(-1)*Ou)*w0+sigy^2;
    totalVar = w0'*Phi*w0+sigy^2;
    R2 = 1-SUB/totalVar;

    % 1. generate DGP (the order of the randn calls is kept: eta, u, uy, loading noise)
    eta = randn(T,N);
    u = sigu*randn(T,N);
    uy = sigy*randn(T,1);
    c = 1;
    sig_n = sqrt(c*Nk^(-1/2));                  % scale of the extra noise added to the loadings
    % NOTE(review): sqrt(Phi) is the element-wise square root, whereas SUB and
    % R2 above are written in terms of Phi itself -- confirm this is intended.
    rootPhi = sqrt(Phi);
    loadNoise = randn(N);
    x = eta*(rootPhi+sig_n*loadNoise)+u;        % add extra noise
    y = eta*rootPhi*w0+uy;
end

function R = pc_est(yt,Xt,Xe,K,L)
% PC_EST  Forecast from K cluster averages formed on principal-component loadings.
%
%   Inputs
%     yt  T-by-1 target over the estimation rows
%     Xt  T-by-N matrix over the estimation rows
%     Xe  matrix with the same N columns over the evaluation rows
%     K   number of clusters
%     L   number of leading right singular vectors used for clustering
%
%   Output
%     R   one combined forecast per row of Xe
%
%   Steps: (1) the columns of Xt are clustered with kmeans on their loadings
%   on the first L right singular vectors of the error matrix yt - Xt;
%   (2) columns are averaged within each cluster; (3) the K averages are
%   combined with weights proportional to SIG^(-1)*1, normalised to sum to
%   one, where SIG is the second-moment matrix of the cluster-average errors.

    % 1. ready input matrix
    et = yt - Xt;                 % error of every column of Xt (yt expanded across columns)
    [~,~,V] = svd(et);
    % svd returns the right singular vectors as the COLUMNS of V, so row i of
    % V(:,1:L) holds the loadings of column i of Xt on the first L components.
    % (Taking V(1:L,:) and transposing would cluster the singular vectors instead.)
    V = V(:,1:L);
    IN = kmeans(V,K);             % divide the N columns into K clusters
    Xo = zeros(size(Xt,1),K);
    Fo = zeros(size(Xe,1),K);
    for i = 1:K
        int = (IN==i);
        Xo(:,i) = mean(Xt(:,int),2);
        Fo(:,i) = mean(Xe(:,int),2);
    end
    % 2. obtain coefficient
    T = size(Xo,1);
    E = yt - Xo;
    SIG = E'*E/T;
    l = ones(K,1);
    w = SIG\l;                    % solve instead of forming SIG^(-1)
    w = w/(l'*w);                 % normalise so the weights sum to one
    R = Fo*w;
end
 




